package com.whoyx.jiebing.utils.nlp.lac_sdk;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.paddlepaddle.engine.PpNDArray;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import com.whoyx.jiebing.config.NlpFilePathConfig;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * DJL {@link Translator} for a PaddlePaddle LAC (lexical analysis) model.
 *
 * <p>{@link #processInput} drops ASCII spaces, splits the sentence into characters, maps each
 * character to a vocabulary id and feeds the model one LoD sequence. {@link #processOutput} maps
 * the predicted tag ids back to tags and merges characters into words. It returns
 * {@code {words, labels}}.
 *
 * <p>Dictionaries are read in {@link #prepare} from the model artifacts {@code lac/word.dic},
 * {@code lac/tag.dic} and {@code lac/q2b.dic}. The characters of the sentence being processed are
 * stored as a {@link TranslatorContext} attachment, not in an instance field. Calls that share this
 * translator therefore do not overwrite each other's input.
 *
 * @author stone
 */
public final class LacTranslator implements Translator<String, String[][]> {

    /** Attachment key under which processInput leaves the characters for processOutput. */
    private static final String CHARS_ATTACHMENT = "lac.chars";

    /** Character to vocabulary id, from lac/word.dic (id-only lines are stored under ""). */
    private final Map<String, String> word2idDict = new HashMap<>();
    /** Vocabulary id to character; populated in prepare but not read in this class. */
    private final Map<String, String> id2wordDict = new HashMap<>();
    /** Tag string to tag id, from lac/tag.dic; populated in prepare but not read in this class. */
    private final Map<String, String> label2idDict = new HashMap<>();
    /** Tag id to tag string, from lac/tag.dic (presumably "label-position", e.g. "X-B"). */
    private final Map<String, String> id2labelDict = new HashMap<>();
    /** Character replacements applied before vocabulary lookup, from lac/q2b.dic. */
    private final Map<String, String> wordReplaceDict = new HashMap<>();
    /** Id used for characters missing from word2idDict: the "OOV" entry of lac/word.dic. */
    private String oovId;

    /**
     * Loads the word, tag and replacement dictionaries from the model artifacts.
     *
     * @param ctx translator context giving access to the model
     * @throws IOException if an artifact is missing or cannot be read
     * @throws IllegalStateException if a lac/tag.dic entry has fewer than two columns
     */
    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        Model model = ctx.getModel();

        // word.dic: "<id>\t<word>". Lines are trimmed when read. An id-only line therefore
        // presumably had a whitespace word that was trimmed away; it is registered under "".
        for (String line : readArtifactLines(model, "lac/word.dic")) {
            if (StringUtils.isEmpty(line)) {
                continue;
            }
            String[] cols = line.split("\t");
            String word = cols.length == 1 ? "" : cols[1];
            word2idDict.put(word, cols[0]);
            id2wordDict.put(cols[0], word);
        }

        // tag.dic: "<id>\t<tag>".
        for (String line : readArtifactLines(model, "lac/tag.dic")) {
            if (StringUtils.isEmpty(line)) {
                continue;
            }
            String[] cols = line.split("\t");
            if (cols.length < 2) {
                throw new IllegalStateException("Malformed entry in lac/tag.dic: " + line);
            }
            label2idDict.put(cols[1], cols[0]);
            id2labelDict.put(cols[0], cols[1]);
        }

        // q2b.dic: "<from>\t<to>".
        // - A blank line registers a mapping from the full-width space literal below to " ".
        // - A single-column line maps the character to "". toWordIds treats "" as "no replacement".
        for (String line : readArtifactLines(model, "lac/q2b.dic")) {
            if (StringUtils.isBlank(line)) {
                wordReplaceDict.put("　", " ");
            } else {
                String[] cols = line.split("\t");
                wordReplaceDict.put(cols[0], cols.length == 1 ? "" : cols[1]);
            }
        }

        oovId = word2idDict.get("OOV");
    }

    /**
     * Decodes the predicted tag ids into words and their labels.
     *
     * <p>A tag ending in "B" or "S" starts a new word (so does the very first tag). Any other tag
     * appends its character to the previous word and replaces that word's label. The label is the
     * tag minus a trailing two-character "-X" position suffix. A tag without such a suffix (e.g. a
     * bare "O") is used as-is instead of crashing.
     *
     * @param translatorContext the same context that was passed to processInput
     * @param ndList model output; element 0 holds the tag ids
     * @return {@code {words, labels}}, two arrays of equal length
     * @throws IllegalStateException if processInput did not run on this context, a tag id is not
     *     in lac/tag.dic, or the model returned more tags than there are input characters
     */
    @Override
    public String[][] processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
        String[] chars = (String[]) translatorContext.getAttachment(CHARS_ATTACHMENT);
        if (chars == null) {
            throw new IllegalStateException("processInput has not been called on this context");
        }
        long[] tagIds = ndList.get(0).toLongArray();
        if (tagIds.length > chars.length) {
            throw new IllegalStateException(
                    "Model returned " + tagIds.length + " tags for " + chars.length + " characters");
        }

        List<String> words = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        for (int i = 0; i < tagIds.length; i++) {
            String tag = id2labelDict.get(String.valueOf(tagIds[i]));
            if (tag == null) {
                throw new IllegalStateException("Tag id " + tagIds[i] + " is not in lac/tag.dic");
            }
            String label = stripPositionSuffix(tag);
            if (words.isEmpty() || tag.endsWith("B") || tag.endsWith("S")) {
                words.add(chars[i]);
                labels.add(label);
            } else {
                int last = words.size() - 1;
                words.set(last, words.get(last) + chars[i]);
                labels.set(last, label);
            }
        }
        return new String[][] {words.toArray(new String[0]), labels.toArray(new String[0])};
    }

    /**
     * Encodes {@code s} as one LoD sequence of vocabulary ids, in an array named "words".
     *
     * <p>ASCII spaces are dropped and the remaining text is split into Unicode code points. The
     * resulting characters are attached to {@code translatorContext} for processOutput.
     *
     * @param translatorContext context providing the NDManager and attachment storage
     * @param s the sentence to analyse
     * @return a list holding one array of shape (n, 1) with LoD {@code [[0, n]]}
     */
    @Override
    public NDList processInput(TranslatorContext translatorContext, String s) throws Exception {
        String[] chars = splitCharacters(s);
        translatorContext.setAttachment(CHARS_ATTACHMENT, chars);

        long[] ids = toWordIds(chars);
        NDManager manager = translatorContext.getNDManager();
        NDArray words = manager.create(ids, new Shape(ids.length, 1));
        words.setName("words");
        // A single sequence spanning every row: offsets [0, n).
        long[][] lod = {{0, ids.length}};
        ((PpNDArray) words).setLoD(lod);
        return new NDList(words);
    }

    /**
     * Returns {@code null}, so DJL does not batchify. Each sentence is encoded as its own LoD
     * tensor in processInput. NOTE(review): presumably a stacking batchifier would lose the LoD —
     * confirm.
     */
    @Override
    public Batchifier getBatchifier() {
        return null;
    }

    /** Reads a model artifact as trimmed lines, failing with a clear message if it is absent. */
    private static List<String> readArtifactLines(Model model, String name) throws IOException {
        URL url = model.getArtifact(name);
        if (url == null) {
            throw new IOException("Model artifact not found: " + name);
        }
        try (InputStream is = url.openStream()) {
            return Utils.readLines(is, true);
        }
    }

    /**
     * Removes ASCII spaces and splits the rest into one string per Unicode code point, so
     * supplementary characters are not torn into lone surrogates. Empty input yields a single
     * empty token, as {@code String.split("")} does.
     */
    private static String[] splitCharacters(String text) {
        String compact = text.replace(" ", "");
        if (compact.isEmpty()) {
            return new String[] {""};
        }
        return compact.codePoints()
                .mapToObj(cp -> new String(Character.toChars(cp)))
                .toArray(String[]::new);
    }

    /** Maps each character to its vocabulary id after applying lac/q2b.dic replacements. */
    private long[] toWordIds(String[] chars) {
        long[] ids = new long[chars.length];
        for (int i = 0; i < chars.length; i++) {
            String ch = chars[i];
            String replacement = wordReplaceDict.get(ch);
            // Blank replacements are ignored. NOTE(review): this also ignores the mapping to " "
            // that prepare adds for blank q2b.dic lines — confirm that is intended.
            if (!StringUtils.isBlank(replacement)) {
                ch = replacement;
            }
            String id = word2idDict.get(ch);
            if (StringUtils.isBlank(id)) {
                if (oovId == null) {
                    throw new IllegalStateException(
                            "lac/word.dic has no OOV entry for unknown character: " + ch);
                }
                id = oovId;
            }
            ids[i] = Long.parseLong(id);
        }
        return ids;
    }

    /**
     * Returns {@code tag} without its trailing "-X" position suffix, or {@code tag} unchanged
     * when it has no such suffix.
     */
    private static String stripPositionSuffix(String tag) {
        int cut = tag.length() - 2;
        return cut >= 0 && tag.charAt(cut) == '-' ? tag.substring(0, cut) : tag;
    }
}
